"""
Phase 3: Mechanisms & Deeper Analysis — Innovation/R&D
======================================================
- Cointegration tests (Kao)
- Bootstrap standard errors for key results
- Placebo / permutation tests
- Leave-one-out country analysis
- Regional jackknife
- Mediation: Z → working_age_share → patents (labor force channel)
- Z x kaopen interactions
- Human capital channel
- Structural break analysis (pre/post GFC)
- R&D vs patents puzzle (efficiency decline)
- Patent quality: nonres_patent_share

Output: output/tables/mechanisms.md, cointegration.md, bootstrap.md,
        placebo.md, leave_one_out.md, regional_jackknife.md,
        structural_break.md, mediation.md, interactions.md, efficiency.md
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd
from scipy import stats

# Hard-coded absolute location of the innovation sub-project.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/innovation")
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral/src" directory first on the import path so the
# `model` module imported on the next line is resolved from there. The path
# insertion must therefore stay ahead of the import.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and output directory for the markdown tables.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"
# Import-time side effect: the tables directory is created if it is missing.
TABLES_DIR.mkdir(parents=True, exist_ok=True)

# ISO3 codes of the 38 countries in the OECD sample list (order preserved).
OECD_38 = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()

# Region name -> explicit list of ISO3 codes. Lookup is by exact membership
# (see get_region); any code not listed here falls into 'Other'.
# Mali (MLI) is a West African country and is therefore listed under
# Sub-Saharan Africa, not Middle East & North Africa.
REGION_MAP = {
    'East Asia & Pacific': [
        'AUS', 'BRN', 'CHN', 'FJI', 'HKG', 'IDN', 'JPN', 'KHM', 'KIR',
        'KOR', 'LAO', 'MAC', 'MHL', 'MMR', 'MNG', 'MYS', 'NZL', 'PHL',
        'PLW', 'PNG', 'SGP', 'SLB', 'THA', 'TLS', 'TON', 'TUV', 'TWN',
        'VNM', 'VUT', 'WSM',
    ],
    'Europe & Central Asia': [
        'ALB', 'AND', 'ARM', 'AUT', 'AZE', 'BEL', 'BGR', 'BIH', 'BLR',
        'CHE', 'CYP', 'CZE', 'DEU', 'DNK', 'ESP', 'EST', 'FIN', 'FRA',
        'GBR', 'GEO', 'GRC', 'HRV', 'HUN', 'IRL', 'ISL', 'ITA', 'KAZ',
        'KGZ', 'LTU', 'LUX', 'LVA', 'MDA', 'MKD', 'MLT', 'MNE', 'NLD',
        'NOR', 'POL', 'PRT', 'ROU', 'RUS', 'SMR', 'SRB', 'SVK', 'SVN',
        'SWE', 'TJK', 'TKM', 'TUR', 'UKR', 'UZB', 'XKX',
    ],
    'Latin America & Caribbean': [
        'ARG', 'ATG', 'BHS', 'BLZ', 'BOL', 'BRA', 'BRB', 'CHL', 'COL',
        'CRI', 'CUB', 'DMA', 'DOM', 'ECU', 'GRD', 'GTM', 'GUY', 'HND',
        'HTI', 'JAM', 'KNA', 'LCA', 'MEX', 'NIC', 'PAN', 'PER', 'PRY',
        'SLV', 'SUR', 'TTO', 'URY', 'VCT', 'VEN',
    ],
    'Middle East & North Africa': [
        'ARE', 'BHR', 'DJI', 'DZA', 'EGY', 'IRN', 'IRQ', 'ISR', 'JOR',
        'KWT', 'LBN', 'LBY', 'MAR', 'OMN', 'PSE', 'QAT', 'SAU', 'SYR',
        'TUN', 'YEM',
    ],
    'South Asia': [
        'AFG', 'BGD', 'BTN', 'IND', 'LKA', 'MDV', 'NPL', 'PAK',
    ],
    'Sub-Saharan Africa': [
        'AGO', 'BDI', 'BEN', 'BFA', 'BWA', 'CAF', 'CIV', 'CMR', 'COD',
        'COG', 'COM', 'CPV', 'ERI', 'ETH', 'GAB', 'GHA', 'GIN', 'GMB',
        'GNB', 'GNQ', 'KEN', 'LBR', 'LSO', 'MDG', 'MLI', 'MOZ', 'MRT',
        'MUS', 'MWI', 'NAM', 'NER', 'NGA', 'RWA', 'SEN', 'SLE', 'SOM',
        'SSD', 'STP', 'SWZ', 'SYC', 'TCD', 'TGO', 'TZA', 'UGA', 'ZAF',
        'ZMB', 'ZWE',
    ],
    'North America': ['CAN', 'USA'],
}


def stars(p):
    """Return the significance marker for p-value `p`.

    '***' below 0.01, '**' below 0.05, '*' below 0.10, otherwise ''.
    All comparisons are strict, so a NaN p-value yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def get_region(iso3):
    """Return the REGION_MAP region whose list contains `iso3`, else 'Other'."""
    return next(
        (name for name, members in REGION_MAP.items() if iso3 in members),
        'Other',
    )


def run_model(df, dep_var, regressors, label, silent=False):
    """Fit PanelGLS of `dep_var` on `regressors` and summarise the fit.

    Regressors that are not columns of `df` are dropped silently. Rows with a
    missing value in the dependent variable or any remaining regressor are
    excluded before fitting.

    Returns None when `dep_var` is not a column or fewer than 50 complete
    rows remain. Otherwise returns a dict with 'label', 'dep_var', 'n_obs',
    'n_countries', 'r_squared', the fitted 'model', and for every regressor
    actually used the keys 'coef_<name>', 'se_<name>' and 'p_<name>'.
    Unless `silent` is true, a one-line status message is printed.
    """
    if dep_var not in df.columns:
        return None

    used = [name for name in regressors if name in df.columns]
    sample = df.dropna(subset=[dep_var] + used).copy()
    n_rows = len(sample)
    if n_rows < 50:
        if not silent:
            print(f"  [{label}] Insufficient obs ({n_rows}) -- skipping")
        return None

    estimator = PanelGLS()
    estimator.fit(sample[dep_var].values, sample[used].values,
                  sample['iso3'].values, sample['year'].values)
    if not silent:
        print(f"  [{label}]  N={estimator.n_obs}, R2={estimator.r_squared:.4f}")

    summary = {
        'label': label,
        'dep_var': dep_var,
        'n_obs': estimator.n_obs,
        'n_countries': estimator.n_countries,
        'r_squared': estimator.r_squared,
        'model': estimator,
    }
    # Coefficient arrays are indexed in the same order as `used`.
    for pos, name in enumerate(used):
        summary[f'coef_{name}'] = estimator.beta[pos]
        summary[f'se_{name}'] = estimator.se[pos]
        summary[f'p_{name}'] = estimator.pvalues[pos]
    return summary


def save_table(rows, filename, title, col_headers=None):
    """Write a list of row dicts to TABLES_DIR/`filename` as a markdown table.

    Nothing is written when `rows` is empty. Columns default to the keys of
    the first row; a key missing from a row renders as an empty cell. Float
    values are formatted to four decimals, everything else via str().
    """
    if not rows:
        return

    headers = list(rows[0].keys()) if col_headers is None else col_headers

    def _cell(value):
        return f"{value:.4f}" if isinstance(value, float) else str(value)

    lines = [
        f"# {title}\n",
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(["---"] * len(headers)) + "|",
    ]
    lines.extend(
        "| " + " | ".join(_cell(row.get(h, '')) for h in headers) + " |"
        for row in rows
    )

    out = TABLES_DIR / filename
    out.write_text('\n'.join(lines))
    print(f"  Saved: {out}")


# ═══════════════════════════════════════════════════════════════════════
# 1. COINTEGRATION TESTS (Kao test)
# ═══════════════════════════════════════════════════════════════════════

def run_cointegration(df):
    """
    Residual-based panel cointegration check, labelled "Kao" in the output.

    For each dependent variable the level regression on Z_1..Z_3 plus the
    available controls is fitted with PanelGLS. For every entity, the change
    in the residual is then regressed on the lagged residual (with an
    intercept, via scipy.stats.linregress) and the slope t-statistic is kept.
    The reported statistic is mean(t) * sqrt(N entities); the p-value is a
    two-sided standard-normal tail probability of that statistic.

    H0: no cointegration. Reject H0 -> variables are cointegrated.

    NOTE(review): averaging per-entity t-statistics is closer to a group-mean
    (IPS / Pedroni-style) statistic than to Kao's (1999) pooled test, and
    Dickey-Fuller t-statistics are not standard normal under a unit-root
    null, so the N(0, 1) two-sided p-value is questionable -- confirm the
    intended test and its null distribution before relying on these numbers.

    Writes cointegration.md via save_table and returns the list of row dicts.
    """
    print("\n" + "=" * 70)
    print("1. COINTEGRATION TESTS (Kao)")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['rgdp_growth', 'kaopen']
    controls = [c for c in controls if c in df.columns]

    dep_vars = [
        ('rd_expenditure_gdp', 'R&D/GDP'),
        ('log_patents', 'log(patents)'),
        ('hightech_exports_share', 'High-tech exports %'),
        ('patents_per_million', 'Patents/million'),
    ]

    rows = []
    for dep, label in dep_vars:
        # Only regressors present in the panel are used; `dep` itself is not
        # checked, so a missing dependent-variable column raises KeyError here.
        regs = demo_vars + controls
        regs = [r for r in regs if r in df.columns]
        sub = df.dropna(subset=[dep] + regs).copy()
        if len(sub) < 100:
            continue

        # Run level regression to get residuals
        gls = PanelGLS()
        gls.fit(sub[dep].values, sub[regs].values,
                sub['iso3'].values, sub['year'].values)

        # Residual-based unit-root regressions, one per entity.
        # presumably resid / entity_ids / time_ids are aligned row-for-row --
        # TODO confirm against model.PanelGLS
        resid = gls.resid
        entities = gls.entity_ids
        times = gls.time_ids

        adf_stats = []
        for entity in np.unique(entities):
            mask = entities == entity
            e_resid = resid[mask]
            e_times = times[mask]
            # Put the entity's residuals in time order. Gaps between years are
            # not checked: consecutive observations are differenced as-is.
            order = np.argsort(e_times)
            e_resid = e_resid[order]

            if len(e_resid) < 5:
                continue

            # Dickey-Fuller regression without lagged differences:
            # delta_e = intercept + alpha * e_{t-1} + eps
            # (linregress always fits an intercept).
            de = np.diff(e_resid)
            e_lag = e_resid[:-1]
            # Unreachable given the len(e_resid) >= 5 guard above (len(de) >= 4).
            if len(de) < 3:
                continue
            slope, intercept, r_val, p_val, std_err = stats.linregress(e_lag, de)
            # Keep the slope t-statistic; entities with a zero standard error
            # are dropped.
            if std_err > 0:
                adf_stats.append(slope / std_err)

        if len(adf_stats) < 5:
            continue

        # Reported statistic: average of the individual t-stats,
        # scaled by sqrt(N)
        mean_t = np.mean(adf_stats)
        n_entities = len(adf_stats)
        # The p-value below treats kao_stat as N(0, 1) under H0 and is
        # two-sided. NOTE(review): see the docstring -- this reference
        # distribution needs confirmation.
        kao_stat = mean_t * np.sqrt(n_entities)
        kao_p = 2 * stats.norm.cdf(-abs(kao_stat))

        print(f"  {label}: Kao stat = {kao_stat:.3f}, p = {kao_p:.4f}, "
              f"N_entities = {n_entities}")
        rows.append({
            'Dep Var': label,
            'Kao Statistic': kao_stat,
            'p-value': kao_p,
            'Sig': stars(kao_p),
            'N entities': n_entities,
            'N obs': gls.n_obs,
        })

    save_table(rows, "cointegration.md", "Kao Panel Cointegration Tests")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# 2. BOOTSTRAP STANDARD ERRORS
# ═══════════════════════════════════════════════════════════════════════

def run_bootstrap(df, n_boot=500):
    """
    Cluster (country-level) bootstrap of the baseline PanelGLS coefficients.

    For each outcome, whole countries are resampled with replacement. Every
    drawn copy gets a unique entity ID so duplicated countries enter the fit
    as separate panel units, and the model is re-estimated. The table reports
    the original coefficient and SE, the bootstrap SE, a 95% percentile
    interval and a sign-based two-sided p-value for Z_1..Z_3.

    Args:
        df: panel with 'iso3', 'year', the Z variables and the outcomes.
        n_boot: number of bootstrap replications per outcome.

    Returns:
        List of row dicts (also written to bootstrap.md via save_table).
        Outcomes with fewer than 100 complete rows, a missing column, or
        fewer than 100 successful replications are skipped.
    """
    print("\n" + "=" * 70)
    print("2. BOOTSTRAP STANDARD ERRORS")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in ['rgdp_growth', 'kaopen'] if c in df.columns]
    regs = demo_vars + controls

    specs = [
        ('log_patents', 'log(patents)'),
        ('hightech_exports_share', 'High-tech exports %'),
        ('patents_per_million', 'Patents/million'),
        ('rd_expenditure_gdp', 'R&D/GDP'),
    ]

    rows = []
    for dep, label in specs:
        if dep not in df.columns:
            print(f"  {label}: column '{dep}' not in panel -- skipping")
            continue
        sub = df.dropna(subset=[dep] + regs).copy()
        if len(sub) < 100:
            continue

        countries = sub['iso3'].unique()
        n_countries = len(countries)

        # Original estimate
        gls = PanelGLS()
        gls.fit(sub[dep].values, sub[regs].values,
                sub['iso3'].values, sub['year'].values)
        orig_betas = gls.beta.copy()
        orig_se = gls.se.copy()

        # Slice each country's rows once instead of re-filtering the whole
        # sample for every draw of every replication.
        country_rows = {c: sub[sub['iso3'] == c] for c in countries}

        print(f"  Bootstrapping {label} ({n_boot} iterations) ...")
        # A private legacy generator yields the same draws as
        # np.random.seed(42) without clobbering the global NumPy RNG state.
        rng = np.random.RandomState(42)
        boot_betas = []
        n_failed = 0
        last_error = None

        for _ in range(n_boot):
            # Cluster bootstrap: resample countries with replacement
            boot_countries = rng.choice(countries, size=n_countries,
                                        replace=True)
            boot_dfs = []
            for i, c in enumerate(boot_countries):
                cdf = country_rows[c].copy()
                cdf['iso3'] = f"{c}_{i}"  # unique ID for duplicates
                boot_dfs.append(cdf)
            boot_df = pd.concat(boot_dfs, ignore_index=True)

            # Best effort: a replication whose fit fails is dropped, but the
            # failure is counted and reported rather than silently ignored.
            try:
                bg = PanelGLS()
                bg.fit(boot_df[dep].values, boot_df[regs].values,
                       boot_df['iso3'].values, boot_df['year'].values)
                boot_betas.append(bg.beta.copy())
            except Exception as exc:
                n_failed += 1
                last_error = exc

        if n_failed:
            print(f"    {n_failed} bootstrap fits failed "
                  f"(last error: {last_error!r})")

        if len(boot_betas) < 100:
            print(f"    Too few successful bootstraps ({len(boot_betas)})")
            continue

        boot_arr = np.array(boot_betas)
        boot_se = np.std(boot_arr, axis=0)
        boot_ci_lo = np.percentile(boot_arr, 2.5, axis=0)
        boot_ci_hi = np.percentile(boot_arr, 97.5, axis=0)

        print(f"    {len(boot_betas)} successful bootstraps")

        for i, var in enumerate(regs):
            if var not in demo_vars:
                continue
            # Share of replications whose sign is opposite to the original
            # estimate, doubled for a two-sided p-value. The doubled share
            # can exceed 1, so it is capped.
            if orig_betas[i] < 0:
                opposite_share = np.mean(boot_arr[:, i] > 0)
            else:
                opposite_share = np.mean(boot_arr[:, i] < 0)
            boot_p = float(min(1.0, 2 * opposite_share))
            rows.append({
                'Dep Var': label,
                'Variable': var,
                'Original Coef': orig_betas[i],
                'Original SE': orig_se[i],
                'Bootstrap SE': boot_se[i],
                'Boot 2.5%': boot_ci_lo[i],
                'Boot 97.5%': boot_ci_hi[i],
                'Boot p': boot_p,
                'Sig': stars(boot_p),
            })

    save_table(rows, "bootstrap.md", "Bootstrap Standard Errors (Cluster by Country)")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# 3. PLACEBO / PERMUTATION TEST
# ═══════════════════════════════════════════════════════════════════════

def run_placebo(df, n_perm=1000):
    """
    Permutation (placebo) test for the Z_1 coefficient.

    Within each year, the values of Z_1, Z_2 and Z_3 are shuffled across
    countries and the model is re-estimated. The p-value is the share of
    permuted |Z_1 coefficients| at least as large as the real one, computed
    as (1 + count) / (1 + n) so a finite number of permutations can never
    report p = 0.

    NOTE(review): each Z variable is shuffled independently, so a permuted
    row mixes Z values from different countries -- confirm this is intended
    rather than permuting the (Z_1, Z_2, Z_3) triple jointly.

    Args:
        df: panel with 'iso3', 'year', the Z variables and the outcomes.
        n_perm: number of permutations per outcome.

    Returns:
        List of row dicts (also written to placebo.md via save_table).
    """
    print("\n" + "=" * 70)
    print("3. PLACEBO / PERMUTATION TEST")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in ['rgdp_growth', 'kaopen'] if c in df.columns]
    regs = demo_vars + controls

    specs = [
        ('log_patents', 'log(patents)'),
        ('hightech_exports_share', 'High-tech exports %'),
    ]

    rows = []
    for dep, label in specs:
        if dep not in df.columns:
            print(f"  {label}: column '{dep}' not in panel -- skipping")
            continue
        sub = df.dropna(subset=[dep] + regs).copy()
        if len(sub) < 100:
            continue

        # Real coefficient
        gls = PanelGLS()
        gls.fit(sub[dep].values, sub[regs].values,
                sub['iso3'].values, sub['year'].values)
        real_z1 = gls.beta[0]  # Z_1 is first in `regs`

        print(f"  Permuting {label} ({n_perm} iterations) ...")
        print(f"    Real Z_1 coef: {real_z1:.4f}")

        # Loop-invariant pieces: row positions for each year (in order of
        # first appearance) and the unshuffled Z columns.
        year_values = sub['year'].values
        year_rows = [np.flatnonzero(year_values == yr)
                     for yr in sub['year'].unique()]
        z_real = {zv: sub[zv].values.copy() for zv in demo_vars}

        # A private legacy generator yields the same shuffles as
        # np.random.seed(123) without clobbering the global NumPy RNG state.
        rng = np.random.RandomState(123)
        perm_coefs = []
        n_failed = 0
        last_error = None

        for _ in range(n_perm):
            # Shuffle Z variables across countries within year
            z_perm = {zv: vals.copy() for zv, vals in z_real.items()}
            for pos in year_rows:
                for zv in demo_vars:
                    vals = z_perm[zv][pos]  # fancy indexing returns a copy
                    rng.shuffle(vals)
                    z_perm[zv][pos] = vals

            perm_df = sub.copy()
            for zv in demo_vars:
                perm_df[zv] = z_perm[zv]

            # Best effort: failed fits are dropped but counted and reported.
            try:
                pg = PanelGLS()
                pg.fit(perm_df[dep].values, perm_df[regs].values,
                       perm_df['iso3'].values, perm_df['year'].values)
                perm_coefs.append(pg.beta[0])
            except Exception as exc:
                n_failed += 1
                last_error = exc

        if n_failed:
            print(f"    {n_failed} permutation fits failed "
                  f"(last error: {last_error!r})")

        if len(perm_coefs) < 100:
            continue

        perm_arr = np.array(perm_coefs)
        # +1 in numerator and denominator counts the observed statistic as
        # one member of the permutation distribution, so p is never 0.
        n_extreme = int(np.sum(np.abs(perm_arr) >= np.abs(real_z1)))
        perm_p = (1 + n_extreme) / (1 + len(perm_arr))
        print(f"    Permutation p-value: {perm_p:.4f}")

        rows.append({
            'Dep Var': label,
            'Real Z_1 Coef': real_z1,
            'Perm Mean': np.mean(perm_arr),
            'Perm SD': np.std(perm_arr),
            'Perm p-value': perm_p,
            'Sig': stars(perm_p),
            'N perms': len(perm_coefs),
        })

    save_table(rows, "placebo.md", "Placebo / Permutation Tests (Z_1)")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# 4. LEAVE-ONE-OUT COUNTRY ANALYSIS
# ═══════════════════════════════════════════════════════════════════════

def run_leave_one_out(df):
    """
    Drop each country one at a time and re-estimate the baseline model.

    A country is flagged as influential when the full-sample Z_1 coefficient
    is significant at the 10% level but becomes insignificant once that
    country is removed. The reverse direction is not flagged.

    Args:
        df: panel with 'iso3', 'year', the Z variables and the outcomes.

    Returns:
        List with one summary row per outcome (also written to
        leave_one_out.md via save_table). An outcome is skipped when its
        column is missing, fewer than 100 complete rows remain, or no
        leave-one-out fit succeeds.
    """
    print("\n" + "=" * 70)
    print("4. LEAVE-ONE-OUT COUNTRY ANALYSIS")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in ['rgdp_growth', 'kaopen'] if c in df.columns]
    regs = demo_vars + controls

    specs = [
        ('log_patents', 'log(patents)'),
        ('hightech_exports_share', 'High-tech exports %'),
    ]

    all_rows = []
    for dep, label in specs:
        if dep not in df.columns:
            print(f"  {label}: column '{dep}' not in panel -- skipping")
            continue
        sub = df.dropna(subset=[dep] + regs).copy()
        if len(sub) < 100:
            continue

        # Full sample estimate
        gls = PanelGLS()
        gls.fit(sub[dep].values, sub[regs].values,
                sub['iso3'].values, sub['year'].values)
        full_z1 = gls.beta[0]
        full_p = gls.pvalues[0]

        countries = sorted(sub['iso3'].unique())
        print(f"  {label}: LOO across {len(countries)} countries (full Z_1 = {full_z1:.4f})")

        loo_z1 = []
        influential = []

        for c in countries:
            loo_df = sub[sub['iso3'] != c].copy()
            if len(loo_df) < 50:
                continue
            # Best effort: a failed refit is reported and skipped rather
            # than silently ignored.
            try:
                lg = PanelGLS()
                lg.fit(loo_df[dep].values, loo_df[regs].values,
                       loo_df['iso3'].values, loo_df['year'].values)
                z1_coef = lg.beta[0]
                z1_p = lg.pvalues[0]
            except Exception as exc:
                print(f"    Dropping {c}: fit failed ({exc!r})")
                continue

            loo_z1.append(z1_coef)
            # Flag if dropping this country removes 10%-level significance
            if full_p < 0.10 and z1_p >= 0.10:
                influential.append(c)
                print(f"    Dropping {c}: Z_1 = {z1_coef:.4f} (p={z1_p:.4f}) -- INFLUENTIAL")

        # np.min / np.max raise on an empty array, so bail out explicitly
        # when no leave-one-out fit produced a coefficient.
        if not loo_z1:
            print(f"    No successful leave-one-out fits for {label} -- skipping")
            continue

        loo_arr = np.array(loo_z1)
        all_rows.append({
            'Dep Var': label,
            'Full Z_1': full_z1,
            'Full p': full_p,
            'LOO Mean Z_1': np.mean(loo_arr),
            'LOO Min Z_1': np.min(loo_arr),
            'LOO Max Z_1': np.max(loo_arr),
            'LOO SD': np.std(loo_arr),
            'N flips': len(influential),
            'Influential': ', '.join(influential) if influential else 'None',
        })

    save_table(all_rows, "leave_one_out.md",
               "Leave-One-Out Country Analysis")
    return all_rows


# ═══════════════════════════════════════════════════════════════════════
# 5. REGIONAL JACKKNIFE
# ═══════════════════════════════════════════════════════════════════════

def run_regional_jackknife(df):
    """Re-estimate the baseline model with one region removed at a time.

    Countries are mapped to regions with get_region. For every outcome the
    table holds one row per dropped region (skipped when fewer than 50 rows
    or fewer than 5 countries remain, or when the fit raises) followed by a
    'FULL SAMPLE' row. Results are written to regional_jackknife.md and
    returned as a list of row dicts.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("5. REGIONAL JACKKNIFE")
    print(rule)

    panel = df.copy()
    panel['region'] = panel['iso3'].apply(get_region)

    optional_controls = [c for c in ('rgdp_growth', 'kaopen')
                         if c in panel.columns]
    regs = ['Z_1', 'Z_2', 'Z_3'] + optional_controls

    outcomes = (
        ('log_patents', 'log(patents)'),
        ('hightech_exports_share', 'High-tech exports %'),
        ('rd_expenditure_gdp', 'R&D/GDP'),
    )

    def _z1_row(dep_label, dropped, n_dropped, fit):
        # Z_1 is the first regressor, hence index 0 throughout.
        return {
            'Dep Var': dep_label,
            'Dropped Region': dropped,
            'N dropped': n_dropped,
            'Z_1 Coef': fit.beta[0],
            'Z_1 SE': fit.se[0],
            'Z_1 p': fit.pvalues[0],
            'Sig': stars(fit.pvalues[0]),
            'R2': fit.r_squared,
        }

    rows = []
    for dep, label in outcomes:
        sample = panel.dropna(subset=[dep] + regs).copy()
        if len(sample) < 100:
            continue

        # Reference fit on the complete sample
        full_fit = PanelGLS()
        full_fit.fit(sample[dep].values, sample[regs].values,
                     sample['iso3'].values, sample['year'].values)

        for region in sorted(sample['region'].unique()):
            in_region = sample['region'] == region
            kept = sample[~in_region].copy()
            too_small = len(kept) < 50 or kept['iso3'].nunique() < 5
            if too_small:
                continue
            try:
                fit = PanelGLS()
                fit.fit(kept[dep].values, kept[regs].values,
                        kept['iso3'].values, kept['year'].values)
                n_dropped = sample[in_region]['iso3'].nunique()
                rows.append(_z1_row(label, region, n_dropped, fit))
            except Exception:
                pass

        # The full-sample row closes each outcome's group
        rows.append(_z1_row(label, 'FULL SAMPLE', 0, full_fit))

    save_table(rows, "regional_jackknife.md", "Regional Jackknife Analysis")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# 6. MEDIATION ANALYSIS: Z → working_age_share → patents
# ═══════════════════════════════════════════════════════════════════════

def run_mediation(df):
    """
    Baron-Kenny style mediation of the Z_1 effect on innovation outcomes.

    Mediators (each used only if present in `df`): working_age_share and
    human_capital. For every outcome/mediator pair three models are fitted
    on a common complete-case sample:

    (a)  mediator ~ Z + controls            -> path a (Z_1 coefficient)
    (c)  outcome  ~ Z + controls            -> total effect c
    (c') outcome  ~ Z + mediator + controls -> direct effect c' and path b

    The indirect effect is reported as c - c', and '% Mediated' as its share
    of c (0 when |c| <= 1e-10). Results go to mediation.md and are returned
    as a list of row dicts.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("6. MEDIATION ANALYSIS")
    print(rule)

    controls = [c for c in ('rgdp_growth', 'kaopen') if c in df.columns]
    z_vars = ['Z_1', 'Z_2', 'Z_3']
    base_regs = z_vars + controls

    candidate_mediators = (
        ('working_age_share', 'Working-age share'),
        ('human_capital', 'Human capital index'),
    )
    mediators = [(col, name) for col, name in candidate_mediators
                 if col in df.columns]

    outcomes = (
        ('log_patents', 'log(patents)'),
        ('hightech_exports_share', 'High-tech exports %'),
    )

    rows = []
    for dep, dep_label in outcomes:
        for med_var, med_label in mediators:
            sample = df.dropna(subset=[dep, med_var] + base_regs).copy()
            if len(sample) < 80:
                continue

            regs_with_mediator = z_vars + [med_var] + controls
            plan = (
                (med_var, base_regs, f"a: Z -> {med_label}"),
                (dep, base_regs, f"c: Z -> {dep_label}"),
                (dep, regs_with_mediator,
                 f"c': Z -> {dep_label} | {med_label}"),
            )
            # Fit a, c, c' in order and stop at the first model that cannot
            # be estimated.
            fits = []
            for target, regressors, tag in plan:
                fit = run_model(sample, target, regressors, tag, silent=True)
                if fit is None:
                    break
                fits.append(fit)
            if len(fits) < len(plan):
                continue
            path_a, total, direct = fits

            a_coef, a_p = path_a['coef_Z_1'], path_a['p_Z_1']
            c_coef, c_p = total['coef_Z_1'], total['p_Z_1']
            cp_coef, cp_p = direct['coef_Z_1'], direct['p_Z_1']
            # Path b is read off the c' regression (mediator coefficient)
            b_coef = direct[f'coef_{med_var}']
            b_p = direct[f'p_{med_var}']

            indirect = c_coef - cp_coef
            pct_mediated = (
                (indirect / c_coef) * 100 if abs(c_coef) > 1e-10 else 0.0
            )

            print(f"  {dep_label} via {med_label}:")
            print(f"    a (Z1 -> med) = {a_coef:.4f} (p={a_p:.4f})")
            print(f"    b (med -> dep) = {b_coef:.4f} (p={b_p:.4f})")
            print(f"    c (total)     = {c_coef:.4f} (p={c_p:.4f})")
            print(f"    c' (direct)   = {cp_coef:.4f} (p={cp_p:.4f})")
            print(f"    Indirect      = {indirect:.4f} ({pct_mediated:.1f}%)")

            rows.append({
                'Dep Var': dep_label,
                'Mediator': med_label,
                'Path a (Z1->M)': a_coef,
                'a p-value': a_p,
                'Path b (M->Y)': b_coef,
                'b p-value': b_p,
                'Total (c)': c_coef,
                "Direct (c')": cp_coef,
                'Indirect': indirect,
                '% Mediated': pct_mediated,
                'N': direct['n_obs'],
            })

    save_table(rows, "mediation.md", "Mediation Analysis: Labor Force & Human Capital Channels")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# 7. Z x KAOPEN INTERACTIONS
# ═══════════════════════════════════════════════════════════════════════

def run_interactions(df):
    """
    Test whether capital account openness modulates the
    demographics-innovation relationship.

    Adds Z_k x kaopen interaction terms (k = 1..3) to the baseline
    specification and reports the interaction coefficients for each outcome.

    Args:
        df: panel with 'iso3', 'year', the Z variables, 'kaopen' and the
            outcome columns. The input frame is not modified.

    Returns:
        List of row dicts (also written to interactions.md via save_table).
        Returns an empty list, without fitting anything, when 'kaopen' is
        not a column of `df`.
    """
    print("\n" + "=" * 70)
    print("7. Z x KAOPEN INTERACTIONS")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']

    # Everywhere else 'kaopen' is an optional control, but the interaction
    # terms cannot be built without it -- skip instead of raising KeyError.
    if 'kaopen' not in df.columns:
        print("  'kaopen' not in panel -- cannot build interactions, skipping")
        return []

    controls = [c for c in ['rgdp_growth', 'kaopen'] if c in df.columns]

    # Create interaction terms on a copy so the caller's frame is untouched
    df = df.copy()
    for zv in demo_vars:
        df[f'{zv}_x_kaopen'] = df[zv] * df['kaopen']

    int_vars = [f'{zv}_x_kaopen' for zv in demo_vars]
    regs = demo_vars + controls + int_vars

    specs = [
        ('rd_expenditure_gdp', 'R&D/GDP'),
        ('log_patents', 'log(patents)'),
        ('hightech_exports_share', 'High-tech exports %'),
        ('nonres_patent_share', 'Non-resident patent share'),
    ]

    rows = []
    for dep, label in specs:
        r = run_model(df, dep, regs, f"Int: {label}")
        if r is None:
            continue
        for var in int_vars:
            ckey = f'coef_{var}'
            if ckey not in r:
                continue
            rows.append({
                'Dep Var': label,
                'Interaction': var,
                'Coef': r[ckey],
                'SE': r[f'se_{var}'],
                'p-value': r[f'p_{var}'],
                'Sig': stars(r[f'p_{var}']),
                'N': r['n_obs'],
                'R2': r['r_squared'],
            })

    save_table(rows, "interactions.md",
               "Demographic-Capital Openness Interactions")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# 8. STRUCTURAL BREAK ANALYSIS
# ═══════════════════════════════════════════════════════════════════════

def run_structural_break(df):
    """
    Split-sample comparison of the Z_1 coefficient around two break years.

    For each outcome and each break (2008 = GFC, 2000), the baseline model is
    fitted separately on years before the break and on years from the break
    onward. A pre/post pair of rows is recorded only when both sub-sample
    fits succeed. Results go to structural_break.md and are returned as a
    list of row dicts.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("8. STRUCTURAL BREAK ANALYSIS")
    print(rule)

    optional_controls = [c for c in ('rgdp_growth', 'kaopen')
                         if c in df.columns]
    regs = ['Z_1', 'Z_2', 'Z_3'] + optional_controls

    outcomes = (
        ('rd_expenditure_gdp', 'R&D/GDP'),
        ('log_patents', 'log(patents)'),
        ('hightech_exports_share', 'High-tech exports %'),
    )
    cutoffs = (
        (2008, 'Pre-GFC', 'Post-GFC'),
        (2000, 'Pre-2000', 'Post-2000'),
    )

    rows = []
    for dep, dep_label in outcomes:
        for break_year, pre_label, post_label in cutoffs:
            windows = (
                (pre_label, df[df['year'] < break_year].copy()),
                (post_label, df[df['year'] >= break_year].copy()),
            )
            # Both halves are always estimated, pre-break first
            fitted = [
                (period,
                 run_model(part, dep, regs, f"{dep_label}: {period}",
                           silent=True))
                for period, part in windows
            ]
            if not all(res for _, res in fitted):
                continue

            rows.extend(
                {
                    'Dep Var': dep_label,
                    'Period': period,
                    'Break Year': break_year,
                    'Z_1 Coef': res.get('coef_Z_1', np.nan),
                    'Z_1 SE': res.get('se_Z_1', np.nan),
                    'Z_1 p': res.get('p_Z_1', np.nan),
                    'Sig': stars(res.get('p_Z_1', 1)),
                    'N': res['n_obs'],
                    'R2': res['r_squared'],
                }
                for period, res in fitted
            )

    save_table(rows, "structural_break.md",
               "Structural Break Analysis")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# 9. EFFICIENCY PUZZLE: R&D vs PATENTS
# ═══════════════════════════════════════════════════════════════════════

def _efficiency_rows(result, model_label, variables):
    """Build one table row per variable of `variables` estimated in `result`.

    Variables without a 'coef_<name>' entry (run_model drops regressors that
    are missing from the panel) are skipped instead of raising KeyError.
    """
    return [
        {
            'Model': model_label,
            'Variable': var,
            'Coef': result[f'coef_{var}'],
            'SE': result[f'se_{var}'],
            'p-value': result[f'p_{var}'],
            'Sig': stars(result[f'p_{var}']),
            'N': result['n_obs'],
            'R2': result['r_squared'],
        }
        for var in variables
        if f'coef_{var}' in result
    ]


def run_efficiency_analysis(df):
    """
    Reconciling: aging -> more R&D spending but fewer patents.

    Builds log R&D efficiency, log(patents_per_million / rd_expenditure_gdp),
    with both the R&D share and the ratio floored at 0.01, and regresses it
    on the Z variables and on the age ratios. The same two specifications are
    run for nonres_patent_share (patent quality proxy), plus R&D/GDP on
    Z + human_capital. Each model is run only if its key columns exist.

    Args:
        df: panel with 'iso3', 'year', the Z variables and whichever of the
            optional columns are available. The input frame is not modified.

    Returns:
        List of row dicts (also written to efficiency.md via save_table).
    """
    print("\n" + "=" * 70)
    print("9. EFFICIENCY PUZZLE: R&D vs PATENTS")
    print("=" * 70)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    age_vars = ['old_dep', 'youth_dep']
    controls = [c for c in ['rgdp_growth', 'kaopen'] if c in df.columns]

    df = df.copy()

    # Create efficiency measure: patents per unit R&D. Missing inputs
    # propagate as NaN through clip, division and log, so no explicit
    # NaN re-assignment is needed.
    if 'patents_per_million' in df.columns and 'rd_expenditure_gdp' in df.columns:
        df['rd_efficiency'] = (
            df['patents_per_million'] / df['rd_expenditure_gdp'].clip(lower=0.01)
        )
        df['log_rd_efficiency'] = np.log(df['rd_efficiency'].clip(lower=0.01))

    hc_vars = demo_vars + ['human_capital']

    # (dependent variable, reported variables, run label, table label),
    # listed in the order the models are estimated and reported.
    model_specs = []
    if 'log_rd_efficiency' in df.columns:
        model_specs += [
            ('log_rd_efficiency', demo_vars,
             "Z -> log(R&D efficiency)", 'Z -> log(patents/R&D)'),
            ('log_rd_efficiency', age_vars,
             "Age ratios -> log(R&D efficiency)",
             'Age ratios -> log(patents/R&D)'),
        ]
    if 'nonres_patent_share' in df.columns:
        model_specs += [
            ('nonres_patent_share', demo_vars,
             "Z -> nonres patent share", 'Z -> nonres patent share'),
            ('nonres_patent_share', age_vars,
             "Age ratios -> nonres patent share",
             'Age ratios -> nonres patent share'),
        ]
    if 'human_capital' in df.columns:
        model_specs.append(
            ('rd_expenditure_gdp', hc_vars,
             "Z + HC -> R&D", 'Z + human_capital -> R&D')
        )

    rows = []
    for dep, report_vars, run_label, table_label in model_specs:
        # Nothing to report if none of the variables of interest exist
        # (run_model would otherwise fit a controls-only model).
        if not any(v in df.columns for v in report_vars):
            print(f"  [{run_label}] None of {report_vars} in panel -- skipping")
            continue
        r = run_model(df, dep, report_vars + controls, run_label)
        if r:
            rows.extend(_efficiency_rows(r, table_label, report_vars))

    save_table(rows, "efficiency.md",
               "R&D Efficiency & Patent Quality Analysis")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# 10. COMPREHENSIVE MECHANISMS TABLE
# ═══════════════════════════════════════════════════════════════════════

def build_mechanisms_table(df):
    """Assemble the consolidated mechanisms table for the paper.

    Three log(patents) models are summarised, each only if it can be fitted:
    A adds working_age_share and B adds human_capital to the Z specification
    (both only when the column exists), and C replaces the Z variables with
    the age ratios old_dep / youth_dep. Rows are written to mechanisms.md
    and returned as a list of dicts.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("10. COMPREHENSIVE MECHANISMS TABLE")
    print(rule)

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    age_vars = ['old_dep', 'youth_dep']
    controls = [c for c in ('rgdp_growth', 'kaopen') if c in df.columns]

    rows = []

    # A and B: add one channel variable to the Z specification
    channel_models = (
        ('working_age_share', 'A: Labor force channel',
         "Mech A: + working_age_share"),
        ('human_capital', 'B: Human capital channel',
         "Mech B: + human_capital"),
    )
    for channel, mechanism, tag in channel_models:
        if channel not in df.columns:
            continue
        fit = run_model(df, 'log_patents',
                        z_vars + [channel] + controls, tag)
        if not fit:
            continue
        rows.append({
            'Mechanism': mechanism,
            'Model': f'log(patents) + {channel}',
            'Z_1 Coef': fit.get('coef_Z_1', np.nan),
            'Z_1 p': fit.get('p_Z_1', np.nan),
            'Channel Coef': fit.get(f'coef_{channel}', np.nan),
            'Channel p': fit.get(f'p_{channel}', np.nan),
            'N': fit['n_obs'],
            'R2': fit['r_squared'],
        })

    # C: age composition -- age ratios instead of the Z variables; the
    # old_dep coefficient is reported as the channel.
    fit = run_model(df, 'log_patents', age_vars + controls,
                    "Mech C: age ratios only")
    if fit:
        rows.append({
            'Mechanism': 'C: Age composition',
            'Model': 'log(patents) ~ old_dep + youth_dep',
            'Z_1 Coef': np.nan,
            'Z_1 p': np.nan,
            'Channel Coef': fit.get('coef_old_dep', np.nan),
            'Channel p': fit.get('p_old_dep', np.nan),
            'N': fit['n_obs'],
            'R2': fit['r_squared'],
        })

    save_table(rows, "mechanisms.md", "Mechanism Tests")
    return rows


# ═══════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════

def main():
    """Load the innovation panel and run every Phase 3 analysis in order."""
    rule = "=" * 70
    print(rule)
    print("PHASE 3: Mechanisms & Deeper Analysis — Innovation/R&D")
    print(rule)

    panel = pd.read_csv(PROCESSED_DIR / "innovation_panel.csv")
    print(f"Panel: {panel['iso3'].nunique()} countries, {len(panel):,} obs")

    # Keep observations up to and including 2024 (the size printed above
    # refers to the unfiltered panel).
    panel = panel[panel['year'] <= 2024].copy()

    # Analyses 1-10 in execution order, with their explicit options.
    steps = (
        (run_cointegration, {}),
        (run_bootstrap, {'n_boot': 500}),
        (run_placebo, {'n_perm': 1000}),
        (run_leave_one_out, {}),
        (run_regional_jackknife, {}),
        (run_mediation, {}),
        (run_interactions, {}),
        (run_structural_break, {}),
        (run_efficiency_analysis, {}),
        (build_mechanisms_table, {}),
    )
    for step, options in steps:
        step(panel, **options)

    print("\n" + rule)
    print("Phase 3 complete. All tables saved to output/tables/")
    print(rule)

# Run the full pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
